import os
import numpy
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from matplotlib.pyplot import MultipleLocator
import argparse
import time
import numpy as np
import pandas as pd
from matplotlib.transforms import Bbox
import pdb


def parse_args():
    """Build the command-line parser for the plotting script and parse ``sys.argv``.

    Per-environment options (``--epochs``, ``--seeds``) are indexed by the
    position of the environment in ``--envs``; ``--used-envs`` holds 1-based
    positions into that same list.

    Returns:
        argparse.Namespace: the parsed options.
    """
    def seed_group(text):
        # One CLI token describes the seeds of ONE environment, e.g. "0,1,2,3",
        # so that several tokens rebuild the nested list used as the default.
        try:
            return [int(token) for token in text.split(',') if token.strip()]
        except ValueError:
            raise argparse.ArgumentTypeError(
                'expected comma-separated integers, got %r' % text)

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Humanoid-v3')
    # type=int: these values are used in arithmetic, so CLI strings must be converted.
    parser.add_argument('--epochs', nargs='+', type=int, default=[25, 75, 75, 75, 75, 75, 75, 75],
                        help='number of epochs to plot, one value per entry of --envs')
    parser.add_argument('--asymptotic-epochs', type=int, default=250)
    parser.add_argument('--steps-per-epoch', type=int, default=4000)
    parser.add_argument('--prefix', type=str, default='')
    parser.add_argument('--filename', type=str, default='progress.txt')
    parser.add_argument('-k', '--target-keys', nargs='+', default=['AverageTestEpRet'])
    # NOTE: the default is a list of lists (one group of directories per
    # experiment) while values given on the command line are plain strings.
    parser.add_argument('--dir-names', nargs='+',
                        default=[
                            ['results/td3'],
                            ['results/sac'],
                            ['results/noda_mlp'],
                            ['results/noda'],
                            ])
    parser.add_argument('--exp-names', nargs='+', default=['TD3', 'SAC', 'AE-SAC', 'NODA-SAC'])
    parser.add_argument('--save-dir', type=str, default='results')
    parser.add_argument('--use-csv', action='store_true', default=False)
    parser.add_argument('--envs', nargs='+', default=['InvertedPendulum-v2', 'HalfCheetah-v3', 'Hopper-v3',
                                                      'Walker2d-v3', 'Ant-v3', 'Humanoid-v3', 'Swimmer-v3',
                                                      'Thrower-v2'])
    parser.add_argument('--seeds', nargs='+', type=seed_group,
                        default=[[0, 1, 2, 3], [0, 2, 3, 5], [1, 2, 3, 5], [0, 1, 2, 3],
                                 [0, 2, 3, 5], [0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 3, 4]],
                        help='one comma-separated seed group per entry of --envs, e.g. 0,1,2,3 0,2,3,5')
    parser.add_argument('--used-envs', nargs='+', type=int, default=[1, 2, 3, 5],
                        help='1-based positions (into --envs) of the environments to plot')
    return parser.parse_args()


def _iter_run_columns(args, dir_name, target_key, seeds):
    """Yield the ``target_key`` column of every selected run under ``dir_name``.

    A run is a sub-directory of ``dir_name`` whose name ends in a digit that is
    contained in ``seeds``. If at least one sub-directory name contains
    ``args.prefix``, only those sub-directories are considered. Each run's
    ``args.filename`` is read as a tab-separated table whose first row holds the
    column names; empty cells are treated as 0.

    Yields:
        numpy.ndarray: 1-D float array with one entry per data row of the run.
    """
    run_dirs = [name for name in os.listdir(dir_name)
                if os.path.isdir(os.path.join(dir_name, name))]
    prefixed = [name for name in run_dirs if args.prefix in name]
    if prefixed:
        run_dirs = prefixed
    for run_dir in run_dirs:
        # The seed is the last character of the directory name; directories
        # without a numeric suffix are not runs and are skipped, not fatal.
        try:
            seed = int(run_dir[-1])
        except ValueError:
            continue
        if seed not in seeds:
            continue
        with open(os.path.join(dir_name, run_dir, args.filename), 'r') as f:
            # Strip only the line terminator: strip() would also eat leading or
            # trailing tabs and thereby drop empty first/last cells.
            rows = [line.rstrip('\r\n').split('\t') for line in f if line.strip()]
        if not rows:
            # Nothing was logged for this run.
            continue
        table = np.array(rows)
        keys = table[0]
        values = table[1:]
        values = np.where(values == '', '0', values).astype('float')
        target_key_index = np.where(keys == target_key)[0][0]
        yield values[:, target_key_index]


def extract_data(args, dir_name, target_key, seeds, epochs=50):
    """Collect the first ``epochs`` values of ``target_key`` for every selected run.

    Args:
        args: namespace providing ``prefix`` and ``filename``.
        dir_name: directory that contains one sub-directory per run.
        target_key: name of the column to extract.
        seeds: container of integer seeds to include.
        epochs: number of leading rows to keep; runs with fewer rows are dropped.

    Returns:
        ``None`` if ``dir_name`` is not a directory, otherwise a float array with
        one row per retained run (empty when no run qualifies).
    """
    if not os.path.isdir(dir_name):
        return None
    curves = [column[:epochs]
              for column in _iter_run_columns(args, dir_name, target_key, seeds)
              if len(column) >= epochs]
    return np.array(curves, dtype='float')


def extract_asymptotic_data(args, dir_name, target_key, seeds):
    """Return the best value of ``target_key`` at epoch ``args.asymptotic_epochs``.

    Args:
        args: namespace providing ``prefix``, ``filename`` and ``asymptotic_epochs``.
        dir_name: directory that contains one sub-directory per run.
        target_key: name of the column to extract.
        seeds: sized container of integer seeds to include.

    Returns:
        The maximum over runs of the value logged at row
        ``args.asymptotic_epochs - 1``, or ``None`` if ``dir_name`` is not a
        directory or the number of sufficiently long runs differs from
        ``len(seeds)``.
    """
    if not os.path.isdir(dir_name):
        return None
    final_values = [column[args.asymptotic_epochs - 1]
                    for column in _iter_run_columns(args, dir_name, target_key, seeds)
                    if len(column) >= args.asymptotic_epochs]
    if len(final_values) == len(seeds):
        return np.max(final_values)
    return None


def _format_thousands(x, _):
    """Tick formatter that renders a step count in thousands (50000 -> '50k')."""
    return '%.0fk' % (x / 1000)


def _load_experiment(args, dir_names, env_dir, target_key, seeds, epochs):
    """Load all runs of one experiment for one environment.

    ``dir_names`` is either a single path or a list of paths; runs found in
    several directories are pooled.

    Returns:
        tuple: ``(runs, asymptote)`` where ``runs`` is a 2-D array with one row
        per run (``None`` if no run was found) and ``asymptote`` is the largest
        asymptotic value reported by any of the directories (``None`` if none).
    """
    if isinstance(dir_names, str):
        # Values given on the command line are plain strings, not lists.
        dir_names = [dir_names]
    curves = []
    asymptote = None
    for dir_name in dir_names:
        path = os.path.join(dir_name, env_dir)
        data = extract_data(args, path, target_key, seeds, epochs)
        if data is None:
            print('Missing data')
        elif len(data) > 0:
            curves.append(data)
        candidate = extract_asymptotic_data(args, path, target_key, seeds)
        if candidate is not None and (asymptote is None or candidate > asymptote):
            asymptote = candidate
    runs = np.concatenate(curves, axis=0) if curves else None
    return runs, asymptote


def draw(args=None):
    """Plot mean +/- std learning curves, one subplot per selected environment.

    For every key in ``args.target_keys`` a figure is created and saved to
    ``args.save_dir`` as both PNG and PDF. Each subplot shows one curve per
    experiment (``args.exp_names`` / ``args.dir_names``) and, when available, a
    dashed horizontal line at the best asymptotic value.

    Args:
        args: parsed options; when omitted they are read from the command line
            at call time.
    """
    if args is None:
        # Parse lazily: a ``parse_args()`` default would run at import time.
        args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)
    n_envs = len(args.used_envs)
    for target_key in args.target_keys:
        fig, axes = plt.subplots(nrows=1, ncols=n_envs, figsize=(10 * n_envs, 6))
        # Experiment name -> first plotted line, so that legend labels always
        # match the curves even when some experiment is skipped.
        legend_entries = {}
        for env_index, used_env in enumerate(args.used_envs):
            ax = axes if n_envs == 1 else axes[env_index]
            i = int(used_env) - 1  # --used-envs is 1-based
            env = args.envs[i]
            epochs = int(args.epochs[i])
            print(env)
            env_dir = 'env' + str(get_env_index(env))

            # Exactly one entry per experiment keeps positions aligned with
            # args.exp_names; missing experiments are recorded as None.
            exp_runs = []
            asymptotic_value = None
            for dir_names in args.dir_names:
                runs, candidate = _load_experiment(args, dir_names, env_dir, target_key,
                                                   args.seeds[i], epochs)
                exp_runs.append(runs)
                if candidate is not None and (asymptotic_value is None or candidate > asymptotic_value):
                    asymptotic_value = candidate

            steps = np.arange(1, epochs + 1) * args.steps_per_epoch
            for j, runs in enumerate(exp_runs):
                if runs is None:
                    continue
                mean = np.mean(runs, axis=0)
                std = np.std(runs, axis=0)
                # Guard against degenerate (scalar or too short) statistics.
                if np.ndim(mean) == 0 or np.ndim(std) == 0:
                    continue
                if len(mean) < len(steps) or len(std) < len(steps):
                    continue
                name = args.exp_names[j]
                line, = ax.plot(steps, mean, label=name, linewidth=3.5)
                ax.fill_between(steps, mean - std, mean + std, alpha=0.3)
                print(name, '{:.1f}±{:.1f}'.format(mean[-1], std[-1]))
                legend_entries.setdefault(name, line)
            if asymptotic_value is not None:
                ax.plot(steps, np.ones_like(steps) * asymptotic_value, linewidth=3.5, linestyle='--')

            ax.set_xlabel('Environment Steps')
            if target_key == 'AverageTestEpRet':
                ax.set_ylabel('Episode Return')
            else:
                ax.set_ylabel(target_key)
            ax.xaxis.set_major_formatter(FuncFormatter(_format_thousands))
            ax.xaxis.set_major_locator(MultipleLocator(50000))
            ax.set_title(env)
            ax.grid(True)
            if n_envs == 1 and legend_entries:
                ax.legend(loc='best')

        save_path_without_extension = os.path.join(args.save_dir, target_key + '_' + str(args.epochs))
        if n_envs > 1:
            if legend_entries:
                fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                           bbox_to_anchor=(0.5, -0.15), loc='lower center', ncol=4)
            # Extend the saved area below the axes so the shared legend is not cropped.
            bbox = Bbox([[0, -0.9], [10 * n_envs, 6]])
            fig.savefig(save_path_without_extension + '.png', bbox_inches=bbox)
            fig.savefig(save_path_without_extension + '.pdf', bbox_inches=bbox)
        else:
            fig.savefig(save_path_without_extension + '.png')
            fig.savefig(save_path_without_extension + '.pdf')
        plt.close(fig)


def get_env_index(env):
    """Map an environment name to its 1-based position in the fixed environment list.

    Raises:
        ValueError: if ``env`` is not one of the known environment names.
    """
    known_envs = [
        'InvertedPendulum-v2',
        'HalfCheetah-v3',
        'Hopper-v3',
        'Walker2d-v3',
        'Ant-v3',
        'Humanoid-v3',
        'Swimmer-v3',
        'Thrower-v2',
    ]
    zero_based = known_envs.index(env)
    return zero_based + 1


def filter_df(df, envs, target_key):
    """Reduce a results table to per-noise-level statistics of ``target_key``.

    Rows are first restricted to fixed hyper-parameter values (for the columns
    that exist) and, when a ``useode`` column exists, to the ``useode`` value of
    the first remaining row. Then, for each environment, the mean and standard
    deviation of ``target_key`` are computed per distinct ``noise`` value.

    Args:
        df: table with at least the columns ``noise``, ``s`` and ``target_key``.
        envs: environment names; used to select rows when an ``env`` column exists.
        target_key: name of the column to summarise.

    Returns:
        list: one array per environment with columns (noise level, mean, std).
    """
    fixed_settings = {'latnoda': 32, 'model_step': 1, 'nodamodelstep': 1, 'updateactionturns': 5}
    for column, required in fixed_settings.items():
        if column in df.columns:
            df = df[df[column] == required]
    if 'useode' in df.columns:
        first_row = df.iloc[0]
        df = df[df['useode'] == first_row['useode']]

    has_env_column = 'env' in df.columns
    per_env_stats = []
    for env in envs:
        subset = df[df['env'] == get_env_index(env)] if has_env_column else df
        table = np.array(subset[['noise', 's', target_key]])
        noise_column = table[:, 0]
        targets = table[:, -1]
        levels = np.unique(noise_column)
        means = np.zeros(levels.shape)
        stds = np.zeros(levels.shape)
        for slot, level in enumerate(levels):
            selected = targets[np.where(noise_column == level)[0]]
            means[slot] = np.mean(selected)
            stds[slot] = np.std(selected)
        per_env_stats.append(np.stack((levels, means, stds), axis=1))
    return per_env_stats


def _single_dir(entry):
    """Return ``entry`` if it is a path string, otherwise its first element.

    The ``--dir-names`` default holds lists of directories while command-line
    values are plain strings; the CSV mode reads one directory per method.
    """
    return entry if isinstance(entry, str) else entry[0]


def _plot_noise_series(ax, stats, label):
    """Plot mean with a +/- std band against noise level for one method.

    ``stats`` is an array with columns (noise level, mean, std).
    """
    noise, mean, std = stats[:, 0], stats[:, 1], stats[:, 2]
    ax.plot(noise, mean, label=label)
    ax.fill_between(noise, mean - std, mean + std, alpha=0.3)


def draw_csv(args=None):
    """Plot ``target_key`` against noise level from ``noda.csv`` and ``sac.csv``.

    One figure per environment in ``args.envs`` and per key in
    ``args.target_keys`` is saved to ``args.save_dir`` as PNG and PDF. When the
    NODA table has a ``useode`` column, its rows are split into two curves
    (``useode == 0`` and ``useode == 1``); otherwise a single NODA curve is drawn.

    Args:
        args: parsed options; when omitted they are read from the command line
            at call time.
    """
    if args is None:
        # Parse lazily: a ``parse_args()`` default would run at import time.
        args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)

    # The first two --dir-names entries hold the NODA and SAC tables; which is
    # which is decided by whether the first experiment name is 'noda'.
    noda_slot, sac_slot = (0, 1) if args.exp_names[0] == 'noda' else (1, 0)
    noda_csv = os.path.join(_single_dir(args.dir_names[noda_slot]), 'noda.csv')
    sac_csv = os.path.join(_single_dir(args.dir_names[sac_slot]), 'sac.csv')

    for target_key in args.target_keys:
        noda_data = pd.read_csv(noda_csv)
        sac_data = filter_df(pd.read_csv(sac_csv), args.envs, target_key)

        # Test for the column explicitly instead of catching KeyError around the
        # whole plotting code, which would also hide unrelated KeyErrors.
        if 'useode' in noda_data.columns:
            noda_ode_data = filter_df(noda_data[noda_data['useode'] == 1], args.envs, target_key)
            noda_no_ode_data = filter_df(noda_data[noda_data['useode'] == 0], args.envs, target_key)
            series = [('SAC', sac_data),
                      ('NODA-MLP-SAC', noda_no_ode_data),
                      ('NODA-ODE-SAC', noda_ode_data)]
        else:
            series = [('SAC', sac_data),
                      ('NODA-SAC', filter_df(noda_data, args.envs, target_key))]

        for i, env in enumerate(args.envs):
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 5))
            for label, per_env_stats in series:
                _plot_noise_series(ax, per_env_stats[i], label)
            ax.set_xlabel('Noise')
            ax.set_ylabel(target_key)
            ax.legend(loc='best')
            save_path_without_extension = os.path.join(args.save_dir, env + '_' + target_key)
            fig.savefig(save_path_without_extension + '.png')
            fig.savefig(save_path_without_extension + '.pdf')
            plt.close(fig)


if __name__ == '__main__':
    # Global figure styling shared by every plot produced by this script.
    plt.rcParams.update({
        'font.sans-serif': ['Times New Roman'],
        'figure.autolayout': True,
        'font.size': 23,
    })
    draw_args = parse_args()
    # --use-csv selects the noise plots built from CSV tables; the default is
    # the learning-curve plot built from per-run progress files.
    render = draw_csv if draw_args.use_csv else draw
    render(draw_args)
